Useful functions

library(tidyverse)
# Wrapper around ggplot() that applies the (non-exported) bigstatsr theme.
# `...` is forwarded to ggplot(); `coeff` is forwarded to the theme helper.
myggplot <- function(..., coeff = 1) {
  apply_theme <- bigstatsr:::MY_THEME
  apply_theme(ggplot(...), coeff = coeff)
}
# Boxplots of the column named by the string `y`, one box per method, coloured
# and filled by `par.dist`, in a facet grid of `par.model` (rows) by
# `par.causal` (columns). `ylab` is the y-axis title (defaults to `y`).
plot_results <- function(results, y, ylab = y) {

  legend_title <- "Distribution\nof effects"
  box_mapping <- aes_string("method", y, color = "par.dist", fill = "par.dist")
  strip_text <- element_text(size = rel(2))

  myggplot(results) +
    geom_boxplot(box_mapping, alpha = 0.3) +
    facet_grid(par.model ~ par.causal) +
    theme(axis.text.x = element_text(angle = 45, hjust = 1),
          strip.text.x = strip_text,
          strip.text.y = strip_text) +
    labs(x = "Method", y = ylab, fill = legend_title, color = legend_title)
}
# Bootstrap estimate of the standard error of a statistic.
#
# x: vector of observations to resample.
# n: number of bootstrap replicates.
# f: statistic computed on each resample (should return a single number).
#
# Returns the standard deviation of `f` over `n` resamples of `x`, each drawn
# with replacement and of the same length as `x`.
boot <- function(x, n = 1e5, f = mean) {
  # Resample through indices: sample(x) treats a single number x >= 1 as 1:x,
  # which would silently bootstrap the wrong values for length-one input.
  n_obs <- length(x)
  sd(replicate(n, f(x[sample.int(n_obs, replace = TRUE)])))
}

Results with T-Trees

# Row positions (after sorting by decreasing score) counted as the top 10% and
# top 20%. NOTE(review): presumes 1100 scored individuals per evaluation --
# confirm.
top10 <- seq_len(110)
top20 <- seq_len(220)

# Put all results in a single tibble
results <- list.files("results1", full.names = TRUE) %>%
  map_dfr(readRDS) %>%
  as_tibble() %>%
  mutate(
    # Collapse each `par.causal` entry into one label such as "30 in HLA"
    par.causal = factor(
      map_chr(par.causal, function(causal) {
        paste(causal[1], causal[2], sep = " in ")
      }),
      levels = c("30 in HLA", "30 in all", "300 in all", "3000 in all")
    ),
    # Columns 1 and 2 of each `eval` element are passed, in that order, to
    # bigstatsr::AUC()
    AUC = map_dbl(eval, function(ev) bigstatsr::AUC(ev[, 1], ev[, 2])),
    # Mean of column 2 among the rows with the largest column-1 values
    percCases10 = map_dbl(eval, function(ev) {
      by_score <- order(ev[, 1], decreasing = TRUE)
      mean(ev[by_score[top10], 2])
    }),
    percCases20 = map_dbl(eval, function(ev) {
      by_score <- order(ev[, 1], decreasing = TRUE)
      mean(ev[by_score[top20], 2])
    })
  )
# Mean timing, number of predictors and predictive performance per simulation
# scenario (all `par*` columns) and method, for the two methods compared here.
results %>%
  filter(method %in% c("T-Trees", "logit-simple")) %>%
  # group on every column whose name starts with "par", plus `method`
  group_by_at(c(vars(starts_with("par")), "method")) %>%
  summarise_at(c("timing", "nb.preds", "AUC", "percCases10", "percCases20"), mean) %>%
  # n = Inf prints every row of the summary
  print(n = Inf)
## # A tibble: 32 x 10
## # Groups:   par.causal, par.dist, par.h2, par.model [?]
##     par.causal par.dist par.h2 par.model       method     timing nb.preds       AUC percCases10 percCases20
##         <fctr>    <chr>  <dbl>     <chr>        <chr>      <dbl>    <dbl>     <dbl>       <dbl>       <dbl>
##  1   30 in HLA gaussian    0.8     fancy logit-simple   509.9386    793.6 0.9001502   0.9272727   0.8245455
##  2   30 in HLA gaussian    0.8     fancy      T-Trees  2314.0056   3300.8 0.8177493   0.8000000   0.6818182
##  3   30 in HLA gaussian    0.8    simple logit-simple   489.3918    948.0 0.9424040   0.9800000   0.8809091
##  4   30 in HLA gaussian    0.8    simple      T-Trees  2410.5636   5768.6 0.8347770   0.8418182   0.6990909
##  5   30 in HLA  laplace    0.8     fancy logit-simple   500.2442    966.2 0.9008663   0.9054545   0.8272727
##  6   30 in HLA  laplace    0.8     fancy      T-Trees  2805.8358   4542.8 0.8607568   0.8127273   0.7427273
##  7   30 in HLA  laplace    0.8    simple logit-simple   500.8506    722.0 0.9302174   0.9636364   0.8681818
##  8   30 in HLA  laplace    0.8    simple      T-Trees  1644.6880   1995.0 0.8481380   0.8509091   0.7300000
##  9   30 in all gaussian    0.8     fancy logit-simple   518.2830    655.6 0.8775273   0.8745455   0.7709091
## 10   30 in all gaussian    0.8     fancy      T-Trees  3483.2610   7330.4 0.8502545   0.8563636   0.7272727
## 11   30 in all gaussian    0.8    simple logit-simple   505.6132    649.0 0.9323514   0.9672727   0.8918182
## 12   30 in all gaussian    0.8    simple      T-Trees  2922.9090   5472.8 0.8310516   0.8490909   0.7327273
## 13   30 in all  laplace    0.8     fancy logit-simple   510.6026    999.2 0.8932313   0.9036364   0.8118182
## 14   30 in all  laplace    0.8     fancy      T-Trees  4265.4018   7901.2 0.8457449   0.8200000   0.7318182
## 15   30 in all  laplace    0.8    simple logit-simple   496.2482    694.8 0.9249049   0.9636364   0.8700000
## 16   30 in all  laplace    0.8    simple      T-Trees  2274.9888   4626.0 0.8322151   0.8109091   0.7118182
## 17  300 in all gaussian    0.8     fancy logit-simple   572.4542   1677.0 0.7796750   0.7254545   0.6481818
## 18  300 in all gaussian    0.8     fancy      T-Trees 12507.7136  32797.4 0.6265848   0.5454545   0.4581818
## 19  300 in all gaussian    0.8    simple logit-simple   554.6804   2396.6 0.8527230   0.8600000   0.7381818
## 20  300 in all gaussian    0.8    simple      T-Trees 10242.1552  29357.4 0.6003698   0.4181818   0.4009091
## 21  300 in all  laplace    0.8     fancy logit-simple   552.4216   1785.4 0.8259872   0.7981818   0.7081818
## 22  300 in all  laplace    0.8     fancy      T-Trees  7936.3082  18462.2 0.7020579   0.5854545   0.5309091
## 23  300 in all  laplace    0.8    simple logit-simple   541.8522   2170.2 0.8604749   0.8545455   0.7381818
## 24  300 in all  laplace    0.8    simple      T-Trees  6510.7720  18024.0 0.6535618   0.5381818   0.4709091
## 25 3000 in all gaussian    0.8     fancy logit-simple   604.9040   5401.0 0.5738792   0.3945455   0.3727273
## 26 3000 in all gaussian    0.8     fancy      T-Trees 16291.6926  40748.6 0.5067176   0.2981818   0.2954545
## 27 3000 in all gaussian    0.8    simple logit-simple   594.6762   6442.2 0.5906757   0.4309091   0.4100000
## 28 3000 in all gaussian    0.8    simple      T-Trees 13911.6304  40190.8 0.5118337   0.3127273   0.3172727
## 29 3000 in all  laplace    0.8     fancy logit-simple   593.8938   3416.2 0.6326039   0.5018182   0.4536364
## 30 3000 in all  laplace    0.8     fancy      T-Trees 17154.7194  40435.0 0.5151644   0.3036364   0.3136364
## 31 3000 in all  laplace    0.8    simple logit-simple   585.4728   5191.2 0.6329355   0.4927273   0.4754545
## 32 3000 in all  laplace    0.8    simple      T-Trees 14272.6158  40443.8 0.5019465   0.3109091   0.3027273
# Results restricted to the two methods compared in this section
ttrees_vs_logit <- results %>%
  filter(method %in% c("T-Trees", "logit-simple"))

# One panel per measure: timing, number of predictors, AUC, cases in top 10%
p_list <- list(
  plot_results(ttrees_vs_logit, y = "timing", ylab = "Timing (in seconds)") +
    scale_y_continuous(breaks = 0:10 * 2000, minor_breaks = NULL),
  plot_results(ttrees_vs_logit, y = "nb.preds",
               ylab = "Number of predictors (log-scale)") +
    scale_y_log10(breaks = c(10^(0:7), 3 * 10^(0:7)), minor_breaks = NULL,
                  labels = scales::comma_format()),
  plot_results(ttrees_vs_logit, y = "AUC") +
    scale_y_continuous(breaks = 0:10 / 10, minor_breaks = (0:9 + 0.5) / 10),
  plot_results(ttrees_vs_logit, y = "percCases10",
               ylab = "Percentage of cases in top 10%") +
    scale_y_continuous(breaks = 0:10 / 10, minor_breaks = (0:9 + 0.5) / 10)
)

# Arrange the four panels (legends removed) in a 2-column grid labelled A-D,
# then put the legend of the first panel to the right of the grid.
p_list %>%
  map(~ .x + theme(legend.position = "none")) %>%
  cowplot::plot_grid(plotlist = ., ncol = 2, align = "hv", scale = 0.9,
                     labels = LETTERS[1:4], label_size = 25) %>%
  cowplot::plot_grid(cowplot::get_legend(p_list[[1]]),
                     rel_widths = c(1, 0.15))
Results of T-Trees vs penalized logistic regression. **A.** Timing (in seconds). **B.** Number of predictors of the model. **C.** AUC. **D.** Percentage of cases in the 10% largest scores.

Results of T-Trees vs penalized logistic regression. A. Timing (in seconds). B. Number of predictors of the model. C. AUC. D. Percentage of cases in the 10% largest scores.

# No `plot` argument is given, so ggsave() writes the last displayed plot;
# `scale` multiplies `width` and `height` (1580/90 by 1070/90 in ggsave's units).
ggsave("figures/ttrees.pdf", scale = 1/90, width = 1580, height = 1070)

Results without T-Trees

Correlation between predictive performance measures

# Row positions (after sorting by decreasing score) counted as the top 10% and
# top 20%. NOTE(review): presumes 1100 scored individuals per evaluation --
# confirm.
top10 <- seq_len(110)
top20 <- seq_len(220)

# Put all results in a single tibble
results2 <- list.files("results2", full.names = TRUE) %>%
  map_dfr(readRDS) %>%
  as_tibble() %>%
  mutate(
    # Collapse each `par.causal` entry into one label such as "30 in HLA"
    par.causal = factor(
      map_chr(par.causal, function(causal) {
        paste(causal[1], causal[2], sep = " in ")
      }),
      levels = c("30 in HLA", "30 in all", "300 in all", "3000 in all")
    ),
    # Columns 1 and 2 of each `eval` element are passed, in that order, to
    # bigstatsr::AUC()
    AUC = map_dbl(eval, function(ev) bigstatsr::AUC(ev[, 1], ev[, 2])),
    # Mean of column 2 among the rows with the largest column-1 values
    percCases10 = map_dbl(eval, function(ev) {
      by_score <- order(ev[, 1], decreasing = TRUE)
      mean(ev[by_score[top10], 2])
    }),
    percCases20 = map_dbl(eval, function(ev) {
      by_score <- order(ev[, 1], decreasing = TRUE)
      mean(ev[by_score[top20], 2])
    })
  )
# Proportion of cases among the highest scores as a function of AUC, for the
# scenarios with heritability H2. Panels A-B are coloured by distribution of
# effects, C-D by model; A and C use the top 10%, B and D the top 20%.
H2 <- 0.8
cowplot::plot_grid(
  myggplot(filter(results2, par.h2 == H2),
           aes(AUC, percCases10, color = par.dist)) +
    geom_point() +
    geom_smooth(method = "lm") +
    labs(y = "Percentage of cases in top 10%",
         color = "Distribution\nof effects") +
    theme(legend.position = c(0.8, 0.2)),
  myggplot(filter(results2, par.h2 == H2),
           aes(AUC, percCases20, color = par.dist)) +
    geom_point() +
    geom_smooth(method = "lm") +
    labs(y = "Percentage of cases in top 20%",
         color = "Distribution\nof effects") +
    theme(legend.position = c(0.8, 0.2)),
  myggplot(filter(results2, par.h2 == H2),
           aes(AUC, percCases10, color = par.model)) +
    geom_point() +
    geom_smooth(method = "lm") +
    scale_colour_brewer(palette = "Set1") +
    labs(y = "Percentage of cases in top 10%",
         color = "Model") +
    theme(legend.position = c(0.8, 0.2)),
  myggplot(filter(results2, par.h2 == H2),
           aes(AUC, percCases20, color = par.model)) +
    geom_point() +
    geom_smooth(method = "lm") +
    scale_colour_brewer(palette = "Set1") +
    labs(y = "Percentage of cases in top 20%",
         color = "Model") +
    theme(legend.position = c(0.8, 0.2)),
  labels = LETTERS[1:4], align = "hv", label_size = 25, scale = 0.95
)
Percentage of cases in the 2 highest deciles of PRSs as a function of AUC.

Percentage of cases in the 2 highest deciles of PRSs as a function of AUC.

# No `plot` argument is given, so ggsave() writes the last displayed plot;
# `scale` multiplies `width` and `height` (1300/90 by 950/90 in ggsave's units).
ggsave("figures/AUC-corr.pdf", scale = 1/90, width = 1300, height = 950)

Results of AUC

# AUC of every method for the scenarios with h2 = 0.8. The dotted blue line
# marks AUC = 0.94 (NOTE(review): presumably the best achievable AUC in this
# setting -- confirm).
plot_results(filter(results2, par.h2 == 0.8), y = "AUC") +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = (0:9 + 0.5) / 10) +
  geom_hline(yintercept = 0.94, color = "blue", linetype = "dotted")
All AUC results for h2=0.8 and all chromosomes

All AUC results for h2=0.8 and all chromosomes

# AUC of each method relative to "logit-simple" within the same simulation:
# rows are grouped by scenario (all `par*` columns) and `num.simu`, then every
# AUC in a group is divided by that group's "logit-simple" AUC.
# NOTE(review): relies on exactly one "logit-simple" row per group -- confirm.
results2 %>%
  filter(par.h2 == 0.8) %>%
  group_by_at(c(vars(starts_with("par")), "num.simu")) %>%
  mutate(AUC_rel = AUC / AUC[method == "logit-simple"]) %>%
  plot_results(y = "AUC_rel", ylab = "Relative AUC / 'logit-simple'") +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = c(0:9 + 0.5) / 10)
All relative AUC results for h2=0.8 and all chromosomes

All relative AUC results for h2=0.8 and all chromosomes

# AUC of every method for the scenarios with h2 = 0.5
plot_results(filter(results2, par.h2 == 0.5), y = "AUC") +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = (0:9 + 0.5) / 10)
All results for h2=0.5 and all chromosomes

All results for h2=0.5 and all chromosomes

# represent h2=0.8 as a function of h2=0.5
# One point per scenario and method: the mean AUC is computed for each value
# of `par.h2`, then spread into one column per value (`0.5` and `0.8`).
results2 %>%
  select(starts_with("par"), method, AUC) %>%
  group_by(par.causal, par.dist, par.model, method, par.h2) %>%
  summarise(AUC = mean(AUC)) %>%
  spread(key = par.h2, value = AUC) %>%
  myggplot(mapping = aes(x = `0.5`, y = `0.8`, color = method)) +
  geom_point(size = 3) +
  geom_smooth(size = 2, alpha = 0.2)
## `geom_smooth()` using method = 'loess'
Results of AUC for all combinations of parameters and methods when h2=0.5 and h2=0.8

Results of AUC for all combinations of parameters and methods when h2=0.5 and h2=0.8

# Correlation, across scenarios and methods, between the mean AUC at h2 = 0.5
# and the mean AUC at h2 = 0.8
results2 %>%
  select(starts_with("par"), method, AUC) %>%
  group_by(par.causal, par.dist, par.model, method, par.h2) %>%
  summarise(AUC = mean(AUC)) %>%
  spread(key = par.h2, value = AUC) %>%
  {
    cor(.[["0.5"]], .[["0.8"]])
  }
## [1] 0.9875672
# Mean AUC of "logit-simple" vs "PRS-max" for each `par.causal` setting
# (Laplace effects, h2 = 0.8, "simple" model). Error bars span +/- 2 bootstrap
# standard errors of the mean; the dashed line marks AUC = 0.5.
results2 %>%
  filter(par.dist == "laplace", par.h2 == 0.8, par.model == "simple",
         method %in% c("logit-simple", "PRS-max")) %>%
  group_by(par.causal, method) %>%
  summarise(AUC_mean = mean(AUC),
            AUC_boot = boot(AUC, n = 1e5, f = mean)) %>%
  myggplot(aes(par.causal, AUC_mean, fill = method, color = method)) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_col(alpha = 0.3, position = "dodge") +
  geom_errorbar(aes(ymin = AUC_mean - 2 * AUC_boot,
                    ymax = AUC_mean + 2 * AUC_boot),
                position = position_dodge(width = 0.9),
                color = "black", width = 0.2) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = 0:10 / 10 + 0.05)
Main results: logit simple vs PRS max

Main results: logit simple vs PRS max

# Mean AUC of every method whose name contains "PRS", for each `par.causal`
# setting (Laplace effects, h2 = 0.8, "simple" model). Error bars span +/- 2
# bootstrap standard errors of the mean; the dashed line marks AUC = 0.5.
results2 %>%
  filter(par.dist == "laplace", par.h2 == 0.8, par.model == "simple",
         grepl("PRS", method)) %>%
  group_by(par.causal, method) %>%
  summarise(AUC_mean = mean(AUC),
            AUC_boot = boot(AUC, n = 1e5, f = mean)) %>%
  myggplot(aes(par.causal, AUC_mean, fill = method, color = method)) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_col(alpha = 0.3, position = "dodge") +
  geom_errorbar(aes(ymin = AUC_mean - 2 * AUC_boot,
                    ymax = AUC_mean + 2 * AUC_boot),
                position = position_dodge(width = 0.9),
                color = "black", width = 0.2) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = 0:10 / 10 + 0.05)
Main results: all PRS

Main results: all PRS

Results for chromosome 6

# Put all results in a single tibble
results3 <- list.files("results3", full.names = TRUE) %>%
  map_dfr(readRDS) %>%
  as_tibble() %>%
  mutate(
    # Collapse each `par.causal` entry into one label such as "30 in HLA"
    par.causal = factor(
      map_chr(par.causal, function(causal) {
        paste(causal[1], causal[2], sep = " in ")
      }),
      levels = c("30 in HLA", "30 in all", "300 in all", "3000 in all")
    ),
    # Columns 1 and 2 of each `eval` element are passed, in that order, to
    # bigstatsr::AUC()
    AUC = map_dbl(eval, function(ev) bigstatsr::AUC(ev[, 1], ev[, 2]))
  )
# AUC of every method for the scenarios with h2 = 0.8. The dotted blue line
# marks AUC = 0.94 (NOTE(review): presumably the best achievable AUC in this
# setting -- confirm).
plot_results(filter(results3, par.h2 == 0.8), y = "AUC") +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = (0:9 + 0.5) / 10) +
  geom_hline(yintercept = 0.94, color = "blue", linetype = "dotted")
All AUC results for h2=0.8 and chromosome 6

All AUC results for h2=0.8 and chromosome 6

# AUC of each method relative to "logit-simple" within the same simulation:
# rows are grouped by scenario (all `par*` columns) and `num.simu`, then every
# AUC in a group is divided by that group's "logit-simple" AUC.
# NOTE(review): relies on exactly one "logit-simple" row per group -- confirm.
results3 %>%
  filter(par.h2 == 0.8) %>%
  group_by_at(c(vars(starts_with("par")), "num.simu")) %>%
  mutate(AUC_rel = AUC / AUC[method == "logit-simple"]) %>%
  plot_results(y = "AUC_rel", ylab = "Relative AUC / 'logit-simple'") +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = c(0:9 + 0.5) / 10)
All relative AUC results for h2=0.8 and chromosome 6

All relative AUC results for h2=0.8 and chromosome 6

# Mean AUC of "logit-simple" vs "PRS-max" for each `par.causal` setting
# (Laplace effects, h2 = 0.8, "simple" model). Error bars span +/- 2 bootstrap
# standard errors of the mean; the dashed line marks AUC = 0.5.
results3 %>%
  filter(par.dist == "laplace", par.h2 == 0.8, par.model == "simple",
         method %in% c("logit-simple", "PRS-max")) %>%
  group_by(par.causal, method) %>%
  summarise(AUC_mean = mean(AUC),
            AUC_boot = boot(AUC, n = 1e5, f = mean)) %>%
  myggplot(aes(par.causal, AUC_mean, fill = method, color = method)) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_col(alpha = 0.3, position = "dodge") +
  geom_errorbar(aes(ymin = AUC_mean - 2 * AUC_boot,
                    ymax = AUC_mean + 2 * AUC_boot),
                position = position_dodge(width = 0.9),
                color = "black", width = 0.2) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = 0:10 / 10 + 0.05)
Main results: logit simple vs PRS max for chromosome 6

Main results: logit simple vs PRS max for chromosome 6

# Mean AUC of every method whose name contains "PRS", for each `par.causal`
# setting (Laplace effects, h2 = 0.8, "simple" model). Error bars span +/- 2
# bootstrap standard errors of the mean; the dashed line marks AUC = 0.5.
results3 %>%
  filter(par.dist == "laplace", par.h2 == 0.8, par.model == "simple",
         grepl("PRS", method)) %>%
  group_by(par.causal, method) %>%
  summarise(AUC_mean = mean(AUC),
            AUC_boot = boot(AUC, n = 1e5, f = mean)) %>%
  myggplot(aes(par.causal, AUC_mean, fill = method, color = method)) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_col(alpha = 0.3, position = "dodge") +
  geom_errorbar(aes(ymin = AUC_mean - 2 * AUC_boot,
                    ymax = AUC_mean + 2 * AUC_boot),
                position = position_dodge(width = 0.9),
                color = "black", width = 0.2) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = 0:10 / 10 + 0.05)
Main results: all PRS for chromosome 6

Main results: all PRS for chromosome 6

Prediction on Celiac

# Embed the pre-rendered Manhattan plot image for the celiac data
knitr::include_graphics("figures/celiac-man.png")
Manhattan plot for Celiac

Manhattan plot for Celiac

# Embed the pre-rendered image of the regularization paths for the celiac data
knitr::include_graphics("figures/celiac-regpath.png")
Regularization paths for the methods. For LR, the line is the result given by CMSA.

Regularization paths for the methods. For LR, the line is the result given by CMSA.

Misc

# Mean AUC as a function of the number of causal SNPs M, one line per method,
# for the scenarios whose `par.causal` label contains "all" (h2 = 0.8 only),
# faceted by distribution of effects (rows) and model (columns).
results2 %>%
  filter(
    # method %in% c("logit-simple", "PRS-max"),
    grepl("all", par.causal)
  ) %>%
  # `par.causal` is a factor, while readr parsers expect character input, so
  # convert explicitly before extracting the leading number (e.g. 300).
  mutate(M = readr::parse_number(as.character(par.causal))) %>%
  group_by(M, par.dist, par.h2, par.model, method) %>%
  summarise(AUC = mean(AUC)) %>%
  # filter(par.h2 == 0.8, par.dist == "gaussian", par.model == "simple") %>%
  filter(par.h2 == 0.8) %>%
  myggplot(aes(M, AUC, color = method)) +
  geom_line() +
  geom_point(size = 2) +
  facet_grid(par.dist ~ par.model) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = c(0:9 + 0.5) / 10) +
  scale_x_sqrt(breaks = c(30, 300, 1200, 3000)) +
  labs(x = "Number of causal SNPs (sqrt-scale)")
AUC as a function of M. For all chromosomes.

AUC as a function of M. For all chromosomes.

# Mean AUC as a function of the number of causal SNPs M, one line per method,
# for the scenarios whose `par.causal` label contains "all" (h2 = 0.8 only),
# faceted by distribution of effects (rows) and model (columns).
results3 %>%
  filter(
    # method %in% c("logit-simple", "PRS-max"),
    grepl("all", par.causal)
  ) %>%
  # `par.causal` is a factor, while readr parsers expect character input, so
  # convert explicitly before extracting the leading number (e.g. 300).
  mutate(M = readr::parse_number(as.character(par.causal))) %>%
  group_by(M, par.dist, par.h2, par.model, method) %>%
  summarise(AUC = mean(AUC)) %>%
  # filter(par.h2 == 0.8, par.dist == "gaussian", par.model == "simple") %>%
  filter(par.h2 == 0.8) %>%
  myggplot(aes(M, AUC, color = method)) +
  geom_line() +
  geom_point(size = 2) +
  facet_grid(par.dist ~ par.model) +
  scale_y_continuous(breaks = 0:10 / 10, minor_breaks = c(0:9 + 0.5) / 10) +
  scale_x_sqrt(breaks = c(30, 300, 1200, 3000)) +
  labs(x = "Number of causal SNPs (sqrt-scale)")
AUC as a function of M. For chromosome 6.

AUC as a function of M. For chromosome 6.

# AUC of each method for the simulations on all chromosomes (`results2`,
# tagged "all") vs chromosome 6 (`results3`, tagged "chr6"), for h2 = 0.8 and
# Gaussian effects; "logit-triple" is left out. Dashed blue line: AUC = 0.94.
bind_rows(
  mutate(results2, simu = "all"),
  mutate(results3, simu = "chr6")
) %>%
  filter(par.h2 == 0.8, par.dist == "gaussian", method != "logit-triple") %>%
  myggplot() +
  geom_boxplot(aes(method, AUC, fill = simu, color = simu), alpha = 0.3) +
  geom_hline(yintercept = 0.94, color = "blue", linetype = "dashed") +
  facet_grid(par.model ~ par.causal) +
  scale_colour_brewer(palette = "Set1") +
  scale_fill_brewer(palette = "Set1") +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        strip.text.x = element_text(size = rel(2)),
        strip.text.y = element_text(size = rel(2)))
AUCs for h2=0.8, dist=gaussian, comparing all chromosomes and chromosome 6.

AUCs for h2=0.8, dist=gaussian, comparing all chromosomes and chromosome 6.

# Embed the pre-rendered image comparing effect sizes (GWAS vs logistic)
knitr::include_graphics("effects.png")
Size of effects GWAS vs logistic, for Celiac

Size of effects GWAS vs logistic, for Celiac

# Embed the pre-rendered density plot of logistic-regression scores by pop.
knitr::include_graphics("preds-density2.png")
Density of scores from logistic regression by pop.

Density of scores from logistic regression by pop.

# Embed the pre-rendered density plot of scores by pop. and genotype
knitr::include_graphics("preds-density3.png")
Density of scores from logistic regression by pop and genotype.

Density of scores from logistic regression by pop and genotype.

# Embed the pre-rendered image of the percentage of controls (see caption)
knitr::include_graphics("perc_cases2.png")
Percent of controls (errors) in 199 (homoz bad in test).

Percent of controls (errors) in 199 (homoz bad in test).

# Embed the pre-rendered image of the Gad score projected on the test set
knitr::include_graphics("gad-pred.png")
Projection on test set of score of Gad, train on same dataset

Projection on test set of score of Gad, train on same dataset